# 实验五：
# 对一副图像进行傅立叶变换，显示频谱，取其5，50，80、150的频率，显示图像，加噪声（高斯，椒盐）进行频率域平滑，锐化，观察图像变化
# 步骤：1：对一副图像进行傅立叶变换，显示频谱，取其5，50，80、150的频率，显示图像
#      2：加噪声（高斯，椒盐）进行频率域平滑，锐化，观察图像变化
#
import cv2
import numpy as np
import random
import math
from matplotlib import pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号


def sp_noise(image, prob):
    # 添加椒盐噪声
    # prob:噪声比例
    output = np.zeros(image.shape, np.uint8)
    thres = 1 - prob
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            rdn = random.random()
            if rdn < prob:
                output[i][j] = 0
            elif rdn > thres:
                output[i][j] = 255
            else:
                output[i][j] = image[i][j]
    return output


def gaussian_noise(image, mean=0, var=0.001):
    # 添加高斯噪声
    # mean : 均值
    # var : 方差
    image = np.array(image / 255, dtype=float)
    noise = np.random.normal(mean, var ** 0.5, image.shape)
    out = image + noise
    if out.min() < 0:
        low_clip = -1.
    else:
        low_clip = 0.
    out = np.clip(out, low_clip, 1.0)
    out = np.uint8(out * 255)
    return out


def dft(self):
    img = cv2.cvtColor(self, cv2.COLOR_BGR2GRAY)
    dft = cv2.dft(np.float32(img), flags=cv2.DFT_COMPLEX_OUTPUT)
    dft_shift = np.fft.fftshift(dft)
    # 构建频谱图
    magnitude_spectrum = 20 * np.log(cv2.magnitude(dft_shift[:, :, 0], dft_shift[:, :, 1]))
    # 画图
    plt.subplot(121), plt.imshow(img, cmap='gray')
    plt.title('Input Image'), plt.xticks([]), plt.yticks([])
    plt.subplot(122), plt.imshow(magnitude_spectrum, cmap='gray')
    plt.title('Magnitude Spectrum'), plt.xticks([]), plt.yticks([])
    plt.show()


src = cv2.imread('img/lena.png')
dft(src)


def low_pass_filtering(image, radius):
    """
    低通滤波函数
    :param image: 输入图像
    :param radius: 半径
    :return: 滤波结果
    """
    # 对图像进行傅里叶变换，fft是一个三维数组，fft[:, :, 0]为实数部分，fft[:, :, 1]为虚数部分
    fft = cv2.dft(np.float32(image), flags=cv2.DFT_COMPLEX_OUTPUT)
    # 对fft进行中心化，生成的dshift仍然是一个三维数组
    dshift = np.fft.fftshift(fft)

    # 得到中心像素
    rows, cols = image.shape[:2]
    mid_row, mid_col = int(rows / 2), int(cols / 2)

    # 构建掩模，256位，两个通道
    mask = np.zeros((rows, cols, 2), np.float32)
    mask[mid_row - radius:mid_row + radius, mid_col - radius:mid_col + radius] = 1

    # 给傅里叶变换结果乘掩模
    fft_filtering = dshift * mask
    # 傅里叶逆变换
    ishift = np.fft.ifftshift(fft_filtering)
    image_filtering = cv2.idft(ishift)
    image_filtering = cv2.magnitude(image_filtering[:, :, 0], image_filtering[:, :, 1])
    # 对逆变换结果进行归一化（一般对图像处理的最后一步都要进行归一化，特殊情况除外）
    cv2.normalize(image_filtering, image_filtering, 0, 1, cv2.NORM_MINMAX)
    return image_filtering


def high_pass_filtering(image, radius, n):
    """
    高通滤波函数
    :param image: 输入图像
    :param radius: 半径
    :param n: ButterWorth滤波器阶数
    :return: 滤波结果
    """
    # 对图像进行傅里叶变换，fft是一个三维数组，fft[:, :, 0]为实数部分，fft[:, :, 1]为虚数部分
    fft = cv2.dft(np.float32(image), flags=cv2.DFT_COMPLEX_OUTPUT)
    # 对fft进行中心化，生成的dshift仍然是一个三维数组
    dshift = np.fft.fftshift(fft)

    # 得到中心像素
    rows, cols = image.shape[:2]
    mid_row, mid_col = int(rows / 2), int(cols / 2)

    # 构建ButterWorth高通滤波掩模

    mask = np.zeros((rows, cols, 2), np.float32)
    for i in range(0, rows):
        for j in range(0, cols):
            # 计算(i, j)到中心点的距离
            d = math.sqrt(pow(i - mid_row, 2) + pow(j - mid_col, 2))
            try:
                mask[i, j, 0] = mask[i, j, 1] = 1 / (1 + pow(radius / d, 2 * n))
            except ZeroDivisionError:
                mask[i, j, 0] = mask[i, j, 1] = 0
    # 给傅里叶变换结果乘掩模
    fft_filtering = dshift * mask
    # 傅里叶逆变换
    ishift = np.fft.ifftshift(fft_filtering)
    image_filtering = cv2.idft(ishift)
    image_filtering = cv2.magnitude(image_filtering[:, :, 0], image_filtering[:, :, 1])
    # 对逆变换结果进行归一化（一般对图像处理的最后一步都要进行归一化，特殊情况除外）
    cv2.normalize(image_filtering, image_filtering, 0, 1, cv2.NORM_MINMAX)
    return image_filtering


def bandpass_filter(image, radius, w, n=1):
    """
    带通滤波函数
    :param image: 输入图像
    :param radius: 带中心到频率平面原点的距离
    :param w: 带宽
    :param n: 阶数
    :return: 滤波结果
    """
    # 对图像进行傅里叶变换，fft是一个三维数组，fft[:, :, 0]为实数部分，fft[:, :, 1]为虚数部分
    fft = cv2.dft(np.float32(image), flags=cv2.DFT_COMPLEX_OUTPUT)
    # 对fft进行中心化，生成的dshift仍然是一个三维数组
    dshift = np.fft.fftshift(fft)

    # 得到中心像素
    rows, cols = image.shape[:2]
    mid_row, mid_col = int(rows / 2), int(cols / 2)

    # 构建掩模，256位，两个通道
    mask = np.zeros((rows, cols, 2), np.float32)
    for i in range(0, rows):
        for j in range(0, cols):
            # 计算(i, j)到中心点的距离
            d = math.sqrt(pow(i - mid_row, 2) + pow(j - mid_col, 2))
            if radius - w / 2 < d < radius + w / 2:
                mask[i, j, 0] = mask[i, j, 1] = 1
            else:
                mask[i, j, 0] = mask[i, j, 1] = 0

    # 给傅里叶变换结果乘掩模
    fft_filtering = dshift * np.float32(mask)
    # 傅里叶逆变换
    ishift = np.fft.ifftshift(fft_filtering)
    image_filtering = cv2.idft(ishift)
    image_filtering = cv2.magnitude(image_filtering[:, :, 0], image_filtering[:, :, 1])
    # 对逆变换结果进行归一化（一般对图像处理的最后一步都要进行归一化，特殊情况除外）
    cv2.normalize(image_filtering, image_filtering, 0, 1, cv2.NORM_MINMAX)
    return image_filtering


if __name__ == "__main__":
    image = cv2.imread("img/lena.png", 0)

    image_low_pass_filtering5 = low_pass_filtering(image, 5)
    image_low_pass_filtering50 = low_pass_filtering(image, 50)
    image_low_pass_filtering80 = low_pass_filtering(image, 80)
    image_low_pass_filtering150 = low_pass_filtering(image, 150)
    plt.subplot(151), plt.imshow(image, 'gray'), plt.title("原图"), plt.xticks([]), plt.yticks([])
    plt.subplot(152), plt.imshow(image_low_pass_filtering5, 'gray'), plt.title("5"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(153), plt.imshow(image_low_pass_filtering50, 'gray'), plt.title("50"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(154), plt.imshow(image_low_pass_filtering80, 'gray'), plt.title("80"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(155), plt.imshow(image_low_pass_filtering150, 'gray'), plt.title("150"), plt.xticks(
        []), plt.yticks([])
    plt.show()

    image_gaussian_noise = gaussian_noise(image, 0.1, 0.03)
    image_sp_noise = sp_noise(image, 0.02)

    image_low_pass_filtering1 = low_pass_filtering(image_gaussian_noise, 30)
    image_high_pass_filtering1 = high_pass_filtering(image_gaussian_noise, 5, 1)

    image_low_pass_filtering2 = low_pass_filtering(image_sp_noise, 50)
    image_high_pass_filtering2 = high_pass_filtering(image_sp_noise, 10, 1)
    # plt.subplot(241), plt.imshow(image, 'gray'), plt.title("原图"), plt.xticks([]), plt.yticks([])
    plt.subplot(241), plt.imshow(image, 'gray'), plt.title("原图"), plt.xticks([]), plt.yticks([])
    plt.subplot(242), plt.imshow(image_gaussian_noise, 'gray'), plt.title("高斯噪声"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(243), plt.imshow(image_low_pass_filtering1, 'gray'), plt.title("低通滤波(模糊、平滑)"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(244), plt.imshow(image_high_pass_filtering1, 'gray'), plt.title("高通滤波(锐化)"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(246), plt.imshow(image_sp_noise, 'gray'), plt.title("椒盐噪声"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(247), plt.imshow(image_low_pass_filtering2, 'gray'), plt.title("低通滤波(模糊、平滑)"), plt.xticks(
        []), plt.yticks([])
    plt.subplot(248), plt.imshow(image_high_pass_filtering2, 'gray'), plt.title("高通滤波(锐化)"), plt.xticks([]), plt.yticks(
        [])
    plt.show()
